"""
Phase 2: Rate Decomposition
============================
Replicates asset_returns Z₁ → 10y bond baseline (~43.7**).
Tests: Z → EMBI/sovereign spread, Z → convenience yield proxy.
Decomposes safe rate into general level + convenience premium.
Subsamples: safe issuers vs non-safe; OECD vs non-OECD.

Output: table2_rate_decomposition.md
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd
import warnings
# Globally silence all Python warnings for the rest of the process.
# NOTE(review): this also hides any warnings raised by numpy, pandas or the
# model code during estimation — confirm that is intended.
warnings.filterwarnings('ignore')

# Absolute root of the safe-assets project.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/safe_assets")
# Sibling "multilateral" directory; its src/ folder is pushed to the front of
# sys.path before importing `model` (presumably where PanelGLS is defined —
# TODO confirm).
MULTILATERAL_DIR = PROJECT_DIR.parent / "multilateral"
sys.path.insert(0, str(MULTILATERAL_DIR / "src"))
from model import PanelGLS

# PROCESSED_DIR holds the input panel CSV read in main();
# TABLES_DIR is where build_table() writes the markdown output.
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
TABLES_DIR = PROJECT_DIR / "output" / "tables"

# ISO3 country codes (38 entries) matched against the panel's `iso3` column
# in main() to split the sample into OECD and non-OECD subsamples.
# NOTE(review): presumably the current OECD membership — verify if it changes.
OECD_38 = [
    "AUS", "AUT", "BEL", "CAN", "CHL", "COL", "CRI", "CZE", "DNK", "EST",
    "FIN", "FRA", "DEU", "GRC", "HUN", "ISL", "IRL", "ISR", "ITA", "JPN",
    "KOR", "LVA", "LTU", "LUX", "MEX", "NLD", "NZL", "NOR", "POL", "PRT",
    "SVK", "SVN", "ESP", "SWE", "CHE", "TUR", "GBR", "USA",
]

# Control regressors appended to every specification; main() keeps only the
# ones that actually exist as columns in the loaded panel.
CONTROLS = ['rgdp_growth', 'inflation', 'fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag']


def stars(p):
    """Return the significance marker for p-value ``p``.

    ``'***'`` for p < 0.01, ``'**'`` for p < 0.05, ``'*'`` for p < 0.10 and
    an empty string otherwise (including when ``p`` is NaN, since every
    comparison against NaN is false).
    """
    # Thresholds are checked from strictest to loosest; the first hit wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.10, '*')):
        if p < cutoff:
            return marker
    return ''


def run_model(df, dep_var, regressors, label, feature_names=None):
    """Fit a PanelGLS model and return its estimates as a flat dict.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel data. Must contain ``iso3`` and ``year`` columns (passed to
        ``PanelGLS.fit``) in addition to the model variables.
    dep_var : str
        Column name of the dependent variable.
    regressors : list of str
        Column names of the regressors. Columns absent from ``df`` are
        reported and dropped from the specification.
    label : str
        Human-readable model label used in console output and in the result.
    feature_names : list of str, optional
        Display names matched positionally to ``regressors``. Defaults to the
        regressor column names. When a regressor is dropped because its
        column is missing, its display name is dropped with it so that every
        reported coefficient keeps the correct label.

    Returns
    -------
    dict or None
        ``label``, ``dep_var``, ``n_obs``, ``n_countries``, ``r_squared``,
        ``rho`` plus ``coef_<name>``, ``se_<name>`` and ``p_<name>`` for each
        named regressor. ``None`` when the dependent variable is missing, no
        regressor is available, or fewer than 50 complete rows remain.
    """
    regressors = list(regressors)
    missing = {c for c in [dep_var] + regressors if c not in df.columns}
    if missing:
        print(f"  [{label}] Missing columns: {missing}")
        if dep_var in missing:
            print(f"  [{label}] Dep var {dep_var} missing — skipping")
            return None

    # Filter regressors and their display names in lockstep. Filtering only
    # the regressors would shift every later name onto the wrong coefficient
    # and eventually index past the end of the estimate arrays.
    present = [r in df.columns for r in regressors]
    kept = [r for r, ok in zip(regressors, present) if ok]
    if feature_names:
        names = [n for n, ok in zip(feature_names, present) if ok]
    else:
        names = kept

    if not kept:
        # Nothing left to estimate; fitting on an empty design is meaningless.
        print(f"  [{label}] No regressors available — skipping")
        return None

    # Listwise deletion on the variables that actually enter the model.
    sub = df.dropna(subset=[dep_var] + kept).copy()
    if len(sub) < 50:
        print(f"  [{label}] Insufficient obs ({len(sub)}) — skipping")
        return None

    gls = PanelGLS()
    gls.fit(sub[dep_var].values, sub[kept].values,
            sub['iso3'].values, sub['year'].values)

    print(f"\n  [{label}]  N={gls.n_obs}, countries={gls.n_countries}, "
          f"R²={gls.r_squared:.4f}, rho={gls.rho:.3f}")

    results = {
        'label': label,
        'dep_var': dep_var,
        'n_obs': gls.n_obs,
        'n_countries': gls.n_countries,
        'r_squared': gls.r_squared,
        'rho': gls.rho,
    }
    # names[i] labels the i-th column of the design matrix built from `kept`.
    for i, name in enumerate(names):
        results[f'coef_{name}'] = gls.beta[i]
        results[f'se_{name}'] = gls.se[i]
        results[f'p_{name}'] = gls.pvalues[i]
        sig = stars(gls.pvalues[i])
        print(f"    {name:<25} {gls.beta[i]:>8.4f} ({gls.se[i]:.4f}) {sig}")

    return results


def main():
    """Run every Phase 2 specification on the panel and write the results table.

    Loads ``safe_asset_panel.csv`` from ``PROCESSED_DIR``, estimates the
    models of sections A-G through ``run_model`` (full sample, safe /
    non-safe issuers, OECD / non-OECD, age decomposition, safe-supply
    controls), collects every non-empty result in order, and hands the list
    to ``build_table``. Models whose optional columns are absent are skipped.
    """
    rule = "=" * 70
    print(rule)
    print("PHASE 2: Rate Decomposition")
    print(rule)

    df = pd.read_csv(PROCESSED_DIR / "safe_asset_panel.csv")
    print(f"Panel: {df['iso3'].nunique()} countries, {len(df):,} obs")

    all_results = []
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = [c for c in CONTROLS if c in df.columns]
    base_spec = demo_vars + controls

    def section(title):
        # Console banner that introduces one block of regressions.
        bar = "─" * 60
        print("\n" + bar)
        print(title)
        print(bar)

    def estimate(frame, dep_var, spec, label):
        # Fit one model; keep the result only if run_model produced one.
        # Fresh list copies are passed for both the regressors and their
        # display names, as in a call built from two list concatenations.
        fitted = run_model(frame, dep_var, list(spec), label, list(spec))
        if fitted:
            all_results.append(fitted)

    # ---------------- Section A: replicate asset_returns baseline ----------
    section("A. Replicate Asset Returns Baseline")
    for dep_var, label in (
        ('real_bond_10y', "M1: Z → real 10y (baseline)"),  # should replicate ~43.7**
        ('real_short_3m', "M2: Z → real 3m"),
        ('term_spread', "M3: Z → term spread"),
    ):
        estimate(df, dep_var, base_spec, label)

    # ---------------- Section B: safe vs non-safe issuers ------------------
    section("B. Safe vs Non-Safe Issuer Subsamples")
    safe = df[df['safe_issuer'] == 1].copy()
    nonsafe = df[df['safe_issuer'] == 0].copy()
    print(f"  Safe issuers: {safe['iso3'].nunique()} countries, {len(safe):,} obs")
    print(f"  Non-safe: {nonsafe['iso3'].nunique()} countries, {len(nonsafe):,} obs")

    estimate(safe, 'real_bond_10y', base_spec, "M4a: Safe → real 10y")
    estimate(nonsafe, 'real_bond_10y', base_spec, "M4b: Non-safe → real 10y")

    # Interaction model on the full sample, only if interaction terms exist.
    interactions = [v for v in ('Z_1_x_safe', 'Z_2_x_safe', 'Z_3_x_safe')
                    if v in df.columns]
    if interactions:
        estimate(df, 'real_bond_10y',
                 demo_vars + controls + ['safe_issuer'] + interactions,
                 "M4c: Z×safe → real 10y")

    # ---------------- Section C: sovereign spread / EMBI proxy -------------
    section("C. Sovereign Spread Regressions")
    estimate(df, 'sovereign_spread', base_spec, "M5a: Z → sovereign spread")
    estimate(nonsafe, 'sovereign_spread', base_spec, "M5b: Non-safe Z → spread")
    if 'domestic_spread' in df.columns:
        estimate(df, 'domestic_spread', base_spec, "M5c: Z → domestic spread")

    # ---------------- Section D: convenience yield proxy -------------------
    section("D. Convenience Yield Proxy")
    if 'lending_govt_spread' in df.columns:
        # Full sample first, then safe issuers only.
        estimate(df, 'lending_govt_spread', base_spec,
                 "M6a: Z → lending-govt spread")
        estimate(safe, 'lending_govt_spread', base_spec,
                 "M6b: Safe Z → lend-govt spread")

    # ---------------- Section E: OECD vs non-OECD --------------------------
    section("E. OECD vs Non-OECD Subsamples")
    in_oecd = df['iso3'].isin(OECD_38)
    oecd = df[in_oecd].copy()
    non_oecd = df[~in_oecd].copy()

    estimate(oecd, 'real_bond_10y', base_spec, "M7a: OECD Z → real 10y")
    estimate(non_oecd, 'real_bond_10y', base_spec, "M7b: Non-OECD Z → real 10y")
    estimate(oecd, 'sovereign_spread', base_spec, "M7c: OECD Z → spread")

    # ---------------- Section F: age decomposition -------------------------
    section("F. Age Decomposition on Rate Components")
    age_spec = ['old_dep', 'youth_dep'] + controls

    estimate(df, 'real_bond_10y', age_spec, "M8a: age → real 10y")
    if 'sovereign_spread' in df.columns:
        estimate(df, 'sovereign_spread', age_spec, "M8b: age → sovereign spread")
    if 'lending_govt_spread' in df.columns:
        estimate(df, 'lending_govt_spread', age_spec, "M8c: age → lend-govt spread")

    # ---------------- Section G: global safe supply ------------------------
    section("G. Global Safe Supply Controls")
    if 'safe_supply_ratio' in df.columns:
        # Does Z survive once safe supply is controlled for?
        estimate(df, 'real_bond_10y',
                 demo_vars + controls + ['safe_supply_ratio'],
                 "M9a: Z → 10y + safe_supply")
        # Does safe supply on its own predict rates?
        estimate(df, 'real_bond_10y',
                 ['safe_supply_ratio'] + controls,
                 "M9b: safe_supply → 10y")

    # ---------------- Results table ----------------------------------------
    print("\n\nBuilding results table ...")
    build_table(all_results)

    print("\n" + rule)
    print("Phase 2 complete.")
    print(rule)


def build_table(all_results):
    """Render the collected model results as markdown and save them to disk.

    Parameters
    ----------
    all_results : list of dict
        Result dicts as returned by ``run_model`` (keys ``label``,
        ``dep_var``, ``n_obs``, ``n_countries``, ``r_squared``, ``rho`` and
        ``coef_/se_/p_<name>`` triples).

    Writes ``table2_rate_decomposition.md`` into ``TABLES_DIR`` (created if
    it does not exist) as UTF-8. The file has three parts: a model summary,
    the key coefficients of every model, and a comparison panel restricted
    to the safe / non-safe models that report a ``Z_1`` coefficient. Prints
    a notice and writes nothing when ``all_results`` is empty.
    """
    if not all_results:
        print("  No results to tabulate.")
        return

    key_vars = ['Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep',
                'safe_issuer', 'Z_1_x_safe', 'Z_2_x_safe', 'Z_3_x_safe',
                'safe_supply_ratio']

    md = ["# Table 2: Rate Decomposition Results\n"]

    # Summary table: one row per estimated model.
    md.append("## Model Summary\n")
    md.append("| Model | Dep Var | N | Countries | R² | ρ |")
    md.append("|---|---|---|---|---|---|")
    for r in all_results:
        md.append(f"| {r['label']} | {r['dep_var']} | {r['n_obs']} "
                  f"| {r['n_countries']} | {r['r_squared']:.3f} | {r['rho']:.3f} |")

    # Key coefficients: only the variables of interest, skipping the controls.
    md.append("\n## Key Coefficients\n")
    md.append("| Model | Variable | Coef | SE | p-value | Sig |")
    md.append("|---|---|---|---|---|---|")
    for r in all_results:
        for var in key_vars:
            ckey = f'coef_{var}'
            if ckey in r:
                p = r[f'p_{var}']
                md.append(f"| {r['label']} | {var} | {r[ckey]:.4f} "
                          f"| {r[f'se_{var}']:.4f} | {p:.4f} | {stars(p)} |")

    md.append(f"\n*Controls: {', '.join(CONTROLS)}*")
    md.append("*PanelGLS with AR(1) correction, no fixed effects.*")
    md.append("*Safe issuer = S&P rating AA- or above (time-varying).*")

    # Safe vs non-safe comparison panel. Only the models whose label marks
    # them as safe / non-safe belong here; iterating over every result would
    # repeat the whole model list under a "Subsample" heading.
    # NOTE(review): the substring match also catches labels that mention
    # "safe_supply" — tighten the filter if those rows are unwanted.
    safe_rows = [r for r in all_results
                 if ('Safe' in r['label'] or 'safe' in r['label'])
                 and 'coef_Z_1' in r]
    if safe_rows:
        md.append("\n## Safe vs Non-Safe Rate Comparison\n")
        md.append("| Subsample | Z₁ Coef | Z₁ p | R² | N |")
        md.append("|-----------|---------|------|-----|---|")
        for r in safe_rows:
            p = r['p_Z_1']
            md.append(f"| {r['label']} | {r['coef_Z_1']:.2f}{stars(p)} "
                      f"| {p:.4f} | {r['r_squared']:.3f} | {r['n_obs']} |")

    out_path = TABLES_DIR / "table2_rate_decomposition.md"
    # A fresh checkout may not have the output folder yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # The table contains non-ASCII symbols (ρ, R², Z₁, →); do not rely on the
    # platform's default encoding.
    out_path.write_text('\n'.join(md), encoding='utf-8')
    print(f"  Saved: {out_path}")


# Run the full Phase 2 pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
